{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook documents the procedure for training the single-label models during the 2023 TRAM effort.\n",
    "\n",
    "The `bootstrap-training-data` file contains the annotations that existed prior to the 2023 effort, as well as the annotations that were produced during it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>network traffic communicates over a raw socket.</td>\n",
       "      <td>T1095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>has the ability to set file attributes to hidden.</td>\n",
       "      <td>T1564.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>searches for files that are 60mb and less and ...</td>\n",
       "      <td>T1083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Attackers can use legitimate domains that are ...</td>\n",
       "      <td>T1090</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>has registered two registry keys for shim data...</td>\n",
       "      <td>T1112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11125</th>\n",
       "      <td>usually the encrypted payload</td>\n",
       "      <td>T1027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11126</th>\n",
       "      <td>includes a capability to modify the Beacon pay...</td>\n",
       "      <td>T1027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11127</th>\n",
       "      <td>has collected the username of the victim system.</td>\n",
       "      <td>T1033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11128</th>\n",
       "      <td>has given malware the same name as an existing...</td>\n",
       "      <td>T1036.005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11129</th>\n",
       "      <td>After the strings are decrypted</td>\n",
       "      <td>T1140</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>11130 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                    text      label\n",
       "0        network traffic communicates over a raw socket.      T1095\n",
       "1      has the ability to set file attributes to hidden.  T1564.001\n",
       "2      searches for files that are 60mb and less and ...      T1083\n",
       "3      Attackers can use legitimate domains that are ...      T1090\n",
       "4      has registered two registry keys for shim data...      T1112\n",
       "...                                                  ...        ...\n",
       "11125                      usually the encrypted payload      T1027\n",
       "11126  includes a capability to modify the Beacon pay...      T1027\n",
       "11127   has collected the username of the victim system.      T1033\n",
       "11128  has given malware the same name as an existing...  T1036.005\n",
       "11129                    After the strings are decrypted      T1140\n",
       "\n",
       "[11130 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import json\n",
    "\n",
    "# Read the annotated sentences. The encoding is pinned so that decoding the file does not\n",
    "# depend on the platform's default text encoding.\n",
    "with open('../data/training/bootstrap-training-data.json', encoding='utf-8') as f:\n",
    "    data_json = json.load(f)\n",
    "\n",
    "# Single-label setup: one row per annotated sentence, labelled with the ATT&CK id of its\n",
    "# first mapping. Sentences without any mapping are dropped.\n",
    "data = pd.DataFrame(\n",
    "    [\n",
    "        {'text': row['text'], 'label': row['mappings'][0]['attack_id']}\n",
    "        for row in data_json['sentences']\n",
    "        if len(row['mappings']) > 0\n",
    "    ]\n",
    ")\n",
    "\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then load the model and move it to the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
    "# Backbone to fine-tune: 'bert' loads the SciBERT checkpoint, 'gpt' loads GPT-2.\n",
    "mode: 'bert or gpt' = 'bert'\n",
    "cuda = torch.device('cuda')\n",
    "\n",
    "# Guard clause: refuse any other mode before a checkpoint is loaded.\n",
    "if mode not in ('bert', 'gpt'):\n",
    "    raise ValueError(f\"mode must be one of bert or gpt, but is {mode = !r}\")\n",
    "\n",
    "if mode == 'gpt':\n",
    "    model = transformers.GPT2ForSequenceClassification.from_pretrained(\n",
    "        \"gpt2\",\n",
    "        num_labels=data['label'].nunique(),\n",
    "        output_attentions=False,\n",
    "        output_hidden_states=False,\n",
    "    )\n",
    "    tokenizer = transformers.GPT2Tokenizer.from_pretrained(\"gpt2\", max_length=512)\n",
    "    # Reuse the end-of-sequence token for padding and tell the model which id that is.\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "else:\n",
    "    model = transformers.BertForSequenceClassification.from_pretrained(\n",
    "        \"allenai/scibert_scivocab_uncased\",\n",
    "        num_labels=data['label'].nunique(),\n",
    "        output_attentions=False,\n",
    "        output_hidden_states=False,\n",
    "    )\n",
    "    tokenizer = transformers.BertTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\", max_length=512)\n",
    "\n",
    "# Put the model in training mode and move it to the GPU; as the last expression of the\n",
    "# cell this also displays the model.\n",
    "model.train().to(cuda)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will represent the labels using one-hot encoding.\n",
    "\n",
    "The `apply_attention_mask` function returns an attention mask (which is a tensor) where the element for every non-padding token is `1` and the element for every padding token is `0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import OneHotEncoder as OHE\n",
    "\n",
    "# One-hot encoder fitted on the label column; dense (non-sparse) output is requested.\n",
    "encoder = OHE(sparse_output=False).fit(data[['label']])\n",
    "\n",
    "def tokenize(samples: 'list[str]'):\n",
    "    \"\"\"Tokenize `samples` and return only their token ids as a PyTorch tensor.\n",
    "\n",
    "    Every sample is truncated or padded to a fixed length of 512 tokens.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        samples,\n",
    "        return_tensors='pt',\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "    )\n",
    "    return encoded.input_ids\n",
    "\n",
    "def load_data(x, y, batch_size=10):\n",
    "    \"\"\"Yield consecutive `(x, y)` mini-batches, each moved to the `cuda` device.\n",
    "\n",
    "    `x` and `y` must have the same size along their first dimension. The last batch\n",
    "    is shorter when that size is not a multiple of `batch_size`.\n",
    "    \"\"\"\n",
    "    n_samples, n_targets = x.shape[0], y.shape[0]\n",
    "    assert n_samples == n_targets\n",
    "    for start in range(0, n_samples, batch_size):\n",
    "        stop = start + batch_size\n",
    "        yield x[start:stop].to(cuda), y[start:stop].to(cuda)\n",
    "\n",
    "def apply_attention_mask(x):\n",
    "    \"\"\"Return an integer mask shaped like `x`: 1 where the token id is not the pad id, 0 elsewhere.\"\"\"\n",
    "    is_not_padding = x != tokenizer.pad_token_id\n",
    "    return is_not_padding.to(int)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  102,   434,   501,  ...,     0,     0,     0],\n",
       "        [  102,  4199, 14562,  ...,     0,     0,     0],\n",
       "        [  102,   147,  3901,  ...,     0,     0,     0],\n",
       "        ...,\n",
       "        [  102, 10740,   111,  ...,     0,     0,     0],\n",
       "        [  102,   434, 13037,  ...,     0,     0,     0],\n",
       "        [  102,   121,   993,  ...,     0,     0,     0]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Pin the seed of the stratified 80/20 split so that the train/test partition, and\n",
    "# therefore the reported scores, do not change from one run of the notebook to the next.\n",
    "SPLIT_SEED = 42\n",
    "\n",
    "train, test = train_test_split(data, test_size=.2, stratify=data['label'], random_state=SPLIT_SEED)\n",
    "\n",
    "x_train = tokenize(train['text'].tolist())\n",
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot targets for the training split as a float32 tensor.\n",
    "y_train = torch.as_tensor(encoder.transform(train[['label']]), dtype=torch.float32)\n",
    "y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hyperparameters shown here are those that we used, including the number of epochs and batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "from tqdm import tqdm\n",
    "from statistics import mean\n",
    "\n",
    "optim = AdamW(model.parameters(), lr=2e-5, eps=1e-8)\n",
    "\n",
    "n_epochs = 6\n",
    "train_batch_size = 10\n",
    "# Number of batches per epoch (ceiling division); load_data is a generator, so tqdm\n",
    "# cannot work this out by itself.\n",
    "n_batches = -(-x_train.shape[0] // train_batch_size)\n",
    "\n",
    "# The evaluation cell below calls model.eval(); switch back to training mode so that\n",
    "# re-running this cell afterwards does not silently train in eval mode.\n",
    "model.train()\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    epoch_losses = []\n",
    "    batches = load_data(x_train, y_train, batch_size=train_batch_size)\n",
    "    for x, y in tqdm(batches, total=n_batches, desc=f\"epoch {epoch + 1}\"):\n",
    "        # Clear the gradients left over from the previous step before the new backward pass.\n",
    "        model.zero_grad()\n",
    "        out = model(x, attention_mask=apply_attention_mask(x), labels=y)\n",
    "        epoch_losses.append(out.loss.item())\n",
    "        out.loss.backward()\n",
    "        optim.step()\n",
    "    print(f\"epoch {epoch + 1} loss: {mean(epoch_losses)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "model.eval()\n",
    "\n",
    "preds = []\n",
    "batch_size = 20\n",
    "\n",
    "x_test = tokenize(test['text'].tolist())\n",
    "\n",
    "# Run inference without tracking gradients, collecting one logits tensor per batch on the CPU.\n",
    "with torch.no_grad():\n",
    "    for i in range(0, x_test.shape[0], batch_size):\n",
    "        x = x_test[i : i + batch_size].to(cuda)\n",
    "        out = model(x, attention_mask=apply_attention_mask(x))\n",
    "        preds.append(out.logits.to('cpu'))\n",
    "\n",
    "# Softmax is monotonic, so taking the argmax of the raw logits selects the same class.\n",
    "predicted_indices = torch.cat(preds).argmax(-1)\n",
    "\n",
    "# The width of the one-hot vectors comes from the fitted encoder instead of a hard-coded\n",
    "# label count, so decoding stays correct when the set of labels in the data changes.\n",
    "n_classes = len(encoder.categories_[0])\n",
    "\n",
    "predicted_labels = (\n",
    "    encoder.inverse_transform(\n",
    "        F.one_hot(predicted_indices, num_classes=n_classes).numpy()\n",
    "    )\n",
    "    .reshape(-1)\n",
    ")\n",
    "\n",
    "predicted_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support as calculate_score\n",
    "\n",
    "actual = test['label'].tolist()\n",
    "predicted = list(predicted_labels)\n",
    "labels = sorted(data['label'].unique())\n",
    "\n",
    "# Per-label precision, recall, F1 and support; after transposing there is one row per label.\n",
    "scores = calculate_score(actual, predicted, labels=labels)\n",
    "scores_df = pd.DataFrame(scores).T\n",
    "scores_df.columns = ['P', 'R', 'F1', '#']\n",
    "scores_df.index = labels\n",
    "\n",
    "# Append the micro- and macro-averaged scores as two extra rows.\n",
    "for average in ('micro', 'macro'):\n",
    "    scores_df.loc[f'({average})'] = calculate_score(actual, predicted, average=average, labels=labels)\n",
    "\n",
    "scores_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
